{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2799a9f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "# dnn模型构建\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import TensorDataset, DataLoader, Dataset\n",
    "import torch.nn.functional as F\n",
    "import random\n",
    "import logging\n",
    "from imblearn.over_sampling import SMOTE\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',\n",
    "        datefmt='%m/%d/%Y %H:%M:%S',\n",
    "        level=logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "25ad14c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 固定随机数种子，确保实验的可重复性\n",
    "def set_seed(seed=42):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "set_seed(42)\n",
    "\n",
    "def get_fft_and_scaler(data, start=5192, end=8192):\n",
    "    data = np.fft.fft(data)\n",
    "    data = np.abs(data)\n",
    "    data = data/np.expand_dims(data.max(axis=1), axis=1)\n",
    "    return data[:, start:end]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f3bbb704",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.dnn = nn.Sequential(\n",
    "            nn.BatchNorm1d(300),\n",
    "            nn.Linear(300, 1024),\n",
    "            nn.BatchNorm1d(1024),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.2),\n",
    "            nn.Linear(1024, 256),\n",
    "            nn.BatchNorm1d(256),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.2),\n",
    "            nn.Linear(256, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.dnn(x)\n",
    "        return F.softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "91fe729d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 搭建DNN模型\n",
    "def model_train(train_loader, model, optimizer, criterion, labels):\n",
    "    model.train()\n",
    "    train_total_acc = 0\n",
    "    train_loss = 0\n",
    "    for feature, label in train_loader:\n",
    "        feature = feature.to(device)\n",
    "        label = label.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        preds = model(feature)\n",
    "\n",
    "        loss = criterion(preds, label)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_total_acc += model(feature).argmax(dim=1).eq(label).sum().item()\n",
    "        train_loss += loss.item()\n",
    "\n",
    "        feature.cpu()\n",
    "        label.cpu()\n",
    "\n",
    "    print(\n",
    "        f'Training loss: {train_loss/len(train_loader):.4f}',\n",
    "        f'Training  acc: {train_total_acc/len(labels):.4f}',\n",
    "         )\n",
    "\n",
    "def predict(val_loader, model, criterion, labels):\n",
    "    model.eval()\n",
    "    val_total_acc = 0\n",
    "    val_loss = 0\n",
    "    for feature, label in val_loader:\n",
    "        feature = feature.to(device)\n",
    "        label = label.to(device)\n",
    "        preds = model(feature)\n",
    "        loss = criterion(preds, label)\n",
    "\n",
    "        val_total_acc += model(feature).argmax(dim=1).eq(label).sum().item()\n",
    "        val_loss += loss.item()\n",
    "\n",
    "        feature.cpu()\n",
    "        label.cpu()\n",
    "\n",
    "    print(\n",
    "        f'Val loss: {val_loss/len(val_loader):.4f}',\n",
    "        f'Val  acc:{val_total_acc/len(labels):.4f}'\n",
    "    )\n",
    "    return val_loss\n",
    "\n",
    "# 使用boost ing的想法，让神经网络学习错误的类别\n",
    "# boosting 的想法：训练后面几次，分类错误的数据\n",
    "# 重新定义一个新的分类器进行学习错误的数据, 保存这些模型的参数\n",
    "# 然后使用模型融合\n",
    "#print('Stage4: start training')\n",
    "# 这里设置boost 的num数，设置为多少就会训练多少个dnn模型\n",
    "#boost_epoch_num = 4\n",
    "#for boost_num in range(boost_epoch_num):\n",
    "#    # 分类器学习, 更新训练集\n",
    "#    boost_feature = []\n",
    "#    boost_label = []\n",
    "#    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "#    criterion = nn.CrossEntropyLoss()\n",
    "#    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)\n",
    "#    train_best = float('inf')\n",
    "#    best_model = None\n",
    "#    print('--'*8 + f'boost round: {boost_num}/{boost_epoch_num - 1}' + '--'*8)\n",
    "#    print(f'train shape: {train_tensor.shape}')\n",
    "#    print('--'*24)\n",
    "#    for epoch in range(epochs):\n",
    "#        print('='*20 + f' Epoch: {epoch} '+ '='*20)\n",
    "#        model_train(train_loader, model, optimizer, criterion=criterion, labels=y_train_tensor)\n",
    "#        loss = predict(val_loader, model, criterion=criterion, labels=val_label)\n",
    "#        if loss <= train_best:\n",
    "#            train_best = loss\n",
    "#            best_model = model\n",
    "#    # 模型保存\n",
    "#    torch.save(best_model.state_dict(), f'./best_model{str(boost_num)}.point')\n",
    "#    get_boost_data(train_tensor.to(device), y_train_tensor.to(device), model=best_model)\n",
    "#    # 开始boosting\n",
    "#    train_tensor = torch.cat(boost_feature, dim=0)\n",
    "#    y_train_tensor = torch.cat(boost_label, dim=0)\n",
    "\n",
    "#print('Stage5: model score')\n",
    "#for i in range(boost_epoch_num):\n",
    "#    model_name = f'./best_model{i}.point'\n",
    "#    # 重新初始化模型\n",
    "#    model = DNN().to(device)\n",
    "#    model.load_state_dict(torch.load(model_name))\n",
    "#    # 这里一定要开启验证模式\n",
    "#    model.eval()\n",
    "#    print('--'*24)\n",
    "#    print(f'Model name: {model_name}')\n",
    "#    # 这里只能用验证集来看准确率\n",
    "#    preds = model(torch.FloatTensor(val_sp).to(device)).argmax(dim=1).cpu().numpy()\n",
    "#    score = (preds == val_label).sum()/len(val_label)\n",
    "#    print(f'Score: {score}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "adde4ccc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stage1: load data\n",
      "Stage2: data over_sampling\n"
     ]
    }
   ],
   "source": [
    "print('Stage1: load data')\n",
    "# 读取训练集，测试集和验证集\n",
    "train = np.load('../train/10type_sort_train_data_8192.npy')\n",
    "val = np.load('../val/10type_sort_eval_data_8192.npy')\n",
    "\n",
    "# 读取训练集和验证集的标签，测试集是没有标签的，需要你使用模型进行分类，并将结果进行提交\n",
    "train_label = np.load('../train/10type_sort_train_label_8192.npy')\n",
    "val_label = np.load('../val/10type_sort_eval_label_8192.npy')\n",
    "\n",
    "print('Stage2: data over_sampling')\n",
    "smote = SMOTE(random_state=42, n_jobs=-1)\n",
    "x_train, y_train = smote.fit_resample(train, train_label)\n",
    "\n",
    "train_sp = get_fft_and_scaler(x_train, start=6892, end=7192)\n",
    "val_sp = get_fft_and_scaler(val, start=6892, end=7192)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "daad5551",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stage3: transform numpy data to tensor\n"
     ]
    }
   ],
   "source": [
    "# 将数据转换成pytorch的tensor\n",
    "print('Stage3: transform numpy data to tensor')\n",
    "batch_size = 128\n",
    "\n",
    "train_tensor = torch.tensor(train_sp).float()\n",
    "y_train_tensor = torch.tensor(y_train).long()\n",
    "val_tensor = torch.tensor(val_sp).float()\n",
    "y_val_tensor = torch.tensor(val_label).long()\n",
    "\n",
    "# 使用Dataloader对数据进行封装\n",
    "train_dataset = TensorDataset(train_tensor, y_train_tensor)\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, drop_last=True)\n",
    "val_dataset = TensorDataset(val_tensor, y_val_tensor)\n",
    "val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "06b4baa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.0001\n",
    "gamma = 0.9\n",
    "step_size = 1\n",
    "epochs = 15\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f114f833",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DNN().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7dfcac1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('./best_model0.point'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b4705198",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Score: 0.6086826475238217\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "preds = model(torch.FloatTensor(val_sp).to(device)).argmax(dim=1).cpu().numpy()\n",
    "score = (preds == val_label).sum()/len(val_label)\n",
    "print(f'Score: {score}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "16489469",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DNN(\n",
      "  (dnn): Sequential(\n",
      "    (0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (1): Linear(in_features=300, out_features=1024, bias=True)\n",
      "    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (3): ReLU()\n",
      "    (4): Dropout(p=0.2, inplace=False)\n",
      "    (5): Linear(in_features=1024, out_features=256, bias=True)\n",
      "    (6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (7): ReLU()\n",
      "    (8): Dropout(p=0.2, inplace=False)\n",
      "    (9): Linear(in_features=256, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "37e46f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "activation = {}\n",
    "def get_activation(name):\n",
    "    def hook(model, input, output):\n",
    "        activation[name] = output.detach()\n",
    "    return hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "a719bbff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.utils.hooks.RemovableHandle at 0x7f4989079400>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.dnn[9].register_forward_hook(get_activation(model.dnn[9]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "b04007f5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 4, 4, ..., 0, 8, 6])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(torch.FloatTensor(val_sp).to(device)).argmax(dim=1).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "fdd1c01a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'9'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-64-363023a50d3e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mactivation\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'9'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m: '9'"
     ]
    }
   ],
   "source": [
    "activation['9'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "5f01016b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dnn\n",
      "Sequential(\n",
      "  (0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (1): Linear(in_features=300, out_features=1024, bias=True)\n",
      "  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (3): ReLU()\n",
      "  (4): Dropout(p=0.2, inplace=False)\n",
      "  (5): Linear(in_features=1024, out_features=256, bias=True)\n",
      "  (6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (7): ReLU()\n",
      "  (8): Dropout(p=0.2, inplace=False)\n",
      "  (9): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "for name, module in model.named_children():\n",
    "    print(name)\n",
    "    print(module)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09cb33cc",
   "metadata": {},
   "source": [
    "# 已知源识别"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aad9dc9",
   "metadata": {},
   "source": [
    "流程：\n",
    "1. 构建训练集特征向量字典\n",
    "    - 因为训练集准确率没有到100，所以有些分类错误的向量不能加入进去\n",
    "        - 思路：将预测值和标签concate起来，变成11维度的向量，然后去除错误的\n",
    "2. 求出每个测试集的特征向量\n",
    "3. 测试集样本特征向量与训练集特征向量余弦相似度计算\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94a2fe00",
   "metadata": {},
   "source": [
    "##    训练集特征向量构建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "id": "9154b81a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建训练集手机特征向量字典\n",
    "train_feature = {}\n",
    "#n_class = 10 # 10 phones\n",
    "#database = {str(i):[] for i in range(n_class)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "id": "c4734cb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "activation = {}\n",
    "def get_activation(name):\n",
    "    def hook(model, input, output):\n",
    "        activation[name] = output.detach()\n",
    "    return hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 274,
   "id": "a4f617b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, ..., 9, 9, 9])"
      ]
     },
     "execution_count": 274,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 抽取中间特征\n",
    "model = DNN().to(device)\n",
    "model.load_state_dict(torch.load('./best_model12.point'))\n",
    "model.eval()\n",
    "# 开始提取\n",
    "model.dnn[9].register_forward_hook(get_activation(model.dnn[9]))\n",
    "model(torch.FloatTensor(train_sp).to(device)).argmax(dim=1).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 275,
   "id": "105b8a27",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.799615988828766"
      ]
     },
     "execution_count": 275,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(preds == y_train).sum() / len(y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "id": "35c430ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = activation[model.dnn[9]].cpu().argmax(dim=1).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "id": "ac2c72d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = np.concatenate([activation[model.dnn[9]].cpu(), np.expand_dims(preds, axis=1)], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 293,
   "id": "6c69ded9",
   "metadata": {},
   "outputs": [],
   "source": [
    "true_data = all_data[all_data[:, -1] == y_train, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 322,
   "id": "fe5b73bb",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# 获得分类正确的标签\n",
    "for i in range(n_class):\n",
    "    train_feature[str(i)] = torch.FloatTensor(true_data[true_data[:, -1]== i, :-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15610608",
   "metadata": {},
   "source": [
    "## 测试集的特征向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 323,
   "id": "9b7b3e1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "activation = {}\n",
    "def get_activation(name):\n",
    "    def hook(model, input, output):\n",
    "        activation[name] = output.detach()\n",
    "    return hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "id": "901f7a89",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 4, 8, ..., 0, 8, 6])"
      ]
     },
     "execution_count": 324,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()\n",
    "# 开始提取\n",
    "model.dnn[9].register_forward_hook(get_activation(model.dnn[9]))\n",
    "model(torch.FloatTensor(val_sp).to(device)).argmax(dim=1).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "id": "227fcc67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "23403"
      ]
     },
     "execution_count": 325,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(val_sp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 326,
   "id": "62033f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_feature = activation[model.dnn[9]].cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 327,
   "id": "b8098cae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([23403, 10])"
      ]
     },
     "execution_count": 327,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_feature.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1d19b6b",
   "metadata": {},
   "source": [
    "## 计算余弦相似度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "id": "de1bcc88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([15573, 10])"
      ]
     },
     "execution_count": 328,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_feature['0'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 329,
   "id": "b46b7ade",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_similarity(source_tensor, object_tensor):\n",
    "    cos = nn.CosineSimilarity(dim=1, eps=1e-6)\n",
    "    output = cos(source_tensor.unsqueeze(dim=0), object_tensor)\n",
    "    scaler_output = output.sum()/object_tensor.size(0)\n",
    "    return scaler_output.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 330,
   "id": "c49b9490",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([15573, 10])"
      ]
     },
     "execution_count": 330,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_feature['0'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 331,
   "id": "ab11dec5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6705741286277771"
      ]
     },
     "execution_count": 331,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_similarity(val_feature[0], torch.FloatTensor(train_feature['0']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 332,
   "id": "5a923ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "similarity_list = []\n",
    "for i in range(n_class):\n",
    "    similarity = get_similarity(val_feature[0], train_feature[str(i)])\n",
    "    similarity_list.append(similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 333,
   "id": "05048474",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.6705741286277771,\n",
       " 0.8546786904335022,\n",
       " 0.2372012883424759,\n",
       " 0.752787709236145,\n",
       " 0.4196513593196869,\n",
       " 0.14867010712623596,\n",
       " 0.4547554552555084,\n",
       " 0.5406652688980103,\n",
       " 0.520589292049408,\n",
       " 0.2931250035762787]"
      ]
     },
     "execution_count": 333,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "similarity_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 334,
   "id": "c502258c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 335,
   "id": "aa9d661a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23403/23403 [01:05<00:00, 357.02it/s]\n"
     ]
    }
   ],
   "source": [
    "preds_list = []\n",
    "for source_tensor in tqdm(val_feature):\n",
    "    similarity_list = []\n",
    "    for j in range(n_class):\n",
    "        similarity = get_similarity(source_tensor, train_feature[str(j)])\n",
    "        similarity_list.append(similarity)\n",
    "    ans = similarity_list.index(max(similarity_list))\n",
    "    preds_list.append(ans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 336,
   "id": "b63563c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(23403, 23403)"
      ]
     },
     "execution_count": 336,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(val_label), len(preds_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 337,
   "id": "d629a6d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5819766696577362"
      ]
     },
     "execution_count": 337,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(np.array(preds_list) == val_label).sum()/len(val_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d8985d2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "165px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
